ACKTR from scratch (low-level PyTorch) — CartPole-v1#
This notebook implements ACKTR-style policy optimization in low-level PyTorch:
Actor update uses a K-FAC preconditioner (approx. Fisher inverse) + trust-region clipping.
Critic is trained as a baseline (value function) with a simple first-order optimizer for stability.
We log training dynamics and visualize them with Plotly, including episodic reward progression.
Prereqs:
PyTorch
Gymnasium
Plotly
Theory reference: see 00_overview.ipynb in this folder.
Notebook roadmap#
Setup + environment
Actor/Critic networks
Rollout collection + GAE
K-FAC optimizer (Linear layers)
Training loop (ACKTR update)
Plotly diagnostics (reward + KL + losses)
Stable-Baselines ACKTR reference + hyperparameters
import random
import time
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
import gymnasium as gym
import torch
import torch.nn as nn
from torch.distributions import Categorical
pio.templates.default = 'plotly_white'
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)
print('NumPy', np.__version__)
print('Pandas', pd.__version__)
print('Plotly', plotly.__version__)
print('Gymnasium', gym.__version__)
print('Torch', torch.__version__)
NumPy 1.26.2
Pandas 2.1.3
Plotly 6.5.2
Gymnasium 1.1.1
Torch 2.7.0+cu126
# --- Reproducibility ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Keep the implementation CPU-friendly and deterministic-ish.
DEVICE = torch.device('cpu')
print('DEVICE', DEVICE)
DEVICE cpu
# --- Run configuration ---
FAST_RUN = True # set False for a longer, smoother curve
# Environment
ENV_ID = 'CartPole-v1'
# Rollout / training
TOTAL_TIMESTEPS = 40_000 if FAST_RUN else 200_000
ROLLOUT_STEPS = 128
# Discounting / advantage
GAMMA = 0.99
GAE_LAMBDA = 0.95
# Loss weights
ENT_COEF = 0.00
# Critic optimizer
CRITIC_LR = 1e-3
# K-FAC / ACKTR knobs (actor)
ACTOR_LR = 0.10
KFAC_DAMPING = 0.03
KFAC_STATS_DECAY = 0.95
KFAC_CLIP = 0.01 # trust region / KL clip (see theory)
INVERSE_UPDATE_INTERVAL = 1
print('TOTAL_TIMESTEPS', TOTAL_TIMESTEPS)
TOTAL_TIMESTEPS 40000
1) Environment#
CartPole-v1 is a classic discrete-action benchmark:
state \(s \in \mathbb{R}^4\)
actions \(a \in \{0,1\}\)
reward \(r_t = 1\) per step until termination
It’s a good fit for a minimal ACKTR demonstration because the policy is a categorical distribution.
env = gym.make(ENV_ID)
obs_dim = int(env.observation_space.shape[0])
act_dim = int(env.action_space.n)
obs, _ = env.reset(seed=SEED)
print('obs_dim', obs_dim, 'act_dim', act_dim)
print('first obs', obs)
obs_dim 4 act_dim 2
first obs [ 0.0274 -0.0061 0.0359 0.0197]
2) Actor–critic parameterization#
We use two networks:
Actor: logits for a categorical policy \(\pi_\theta(a\mid s)\).
Critic: a value function baseline \(V_\phi(s)\).
The actor loss (policy gradient with entropy bonus) is
\[ \mathcal{L}_\pi(\theta) = -\,\mathbb{E}_t\!\big[\log \pi_\theta(a_t \mid s_t)\,\hat A_t\big] \;-\; \beta\,\mathbb{E}_t\!\big[\mathcal{H}\big(\pi_\theta(\cdot \mid s_t)\big)\big], \]
with \(\beta\) the entropy weight (ENT_COEF below).
The critic trains by regression to (bootstrapped) returns:
\[ \mathcal{L}_V(\phi) = \tfrac{1}{2}\,\mathbb{E}_t\!\big[\big(\hat R_t - V_\phi(s_t)\big)^2\big] \]
class Actor(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden_sizes=(64, 64)):
super().__init__()
layers = []
last = obs_dim
for h in hidden_sizes:
layers.append(nn.Linear(last, h))
layers.append(nn.Tanh())
last = h
self.net = nn.Sequential(*layers)
self.logits = nn.Linear(last, act_dim)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
return self.logits(self.net(obs))
class Critic(nn.Module):
def __init__(self, obs_dim: int, hidden_sizes=(64, 64)):
super().__init__()
layers = []
last = obs_dim
for h in hidden_sizes:
layers.append(nn.Linear(last, h))
layers.append(nn.Tanh())
last = h
self.net = nn.Sequential(*layers)
self.v = nn.Linear(last, 1)
def forward(self, obs: torch.Tensor) -> torch.Tensor:
return self.v(self.net(obs)).squeeze(-1)
actor = Actor(obs_dim, act_dim).to(DEVICE)
critic = Critic(obs_dim).to(DEVICE)
critic_optim = torch.optim.Adam(critic.parameters(), lr=CRITIC_LR)
print('actor params', sum(p.numel() for p in actor.parameters()))
print('critic params', sum(p.numel() for p in critic.parameters()))
actor params 4610
critic params 4545
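To make the loss definitions above concrete, here is a minimal sketch using torch.distributions.Categorical on toy tensors (the logits, actions, and advantages are made-up numbers, not notebook data):

```python
import torch
from torch.distributions import Categorical

torch.manual_seed(0)
# Toy batch (made-up numbers): logits for 4 states, sampled actions, advantages.
logits = torch.randn(4, 2, requires_grad=True)
actions = torch.tensor([0, 1, 1, 0])
adv = torch.tensor([0.5, -0.2, 1.0, 0.3])
ent_coef = 0.01  # hypothetical entropy weight for this sketch

dist = Categorical(logits=logits)
# Policy-gradient surrogate minus entropy bonus, matching the actor loss above.
actor_loss = -(dist.log_prob(actions) * adv).mean() - ent_coef * dist.entropy().mean()
actor_loss.backward()  # gradients flow back into the logits
print(float(actor_loss))
```

The same pattern appears in the training loop below, with network outputs in place of the toy logits.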
3) Rollouts + GAE#
We collect on-policy rollouts of length \(T\) and compute generalized advantage estimation (GAE):
\[ \delta_t = r_t + \gamma\,(1 - d_t)\,V(s_{t+1}) - V(s_t), \qquad \hat A_t = \delta_t + \gamma \lambda\,(1 - d_t)\,\hat A_{t+1}, \]
with \(d_t \in \{0,1\}\) indicating episode termination; the critic's regression targets are \(\hat R_t = \hat A_t + V(s_t)\).
def compute_gae(rewards, values, dones, last_value, gamma: float, lam: float):
"""NumPy GAE for a single rollout segment."""
T = len(rewards)
advantages = np.zeros(T, dtype=np.float32)
gae = 0.0
for t in reversed(range(T)):
next_value = last_value if t == T - 1 else values[t + 1]
next_nonterminal = 1.0 - dones[t]
delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
gae = delta + gamma * lam * next_nonterminal * gae
advantages[t] = gae
returns = advantages + values
return advantages, returns
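As a quick sanity check of the recursion (toy numbers; the function is restated inline so this cell runs standalone): a done flag at the last step must mask out the bootstrap value, and with \(\lambda = 0\) the advantages reduce to one-step TD errors.

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma, lam):
    # Same logic as the cell above, restated so this check is self-contained.
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    return advantages, advantages + values

# Toy segment: 3 steps, episode terminates at t=2.
rewards = np.array([1.0, 1.0, 1.0], dtype=np.float32)
values = np.array([0.5, 0.5, 0.5], dtype=np.float32)
dones = np.array([0.0, 0.0, 1.0], dtype=np.float32)
# last_value is deliberately absurd: the done flag must mask it out.
adv, ret = compute_gae(rewards, values, dones, last_value=9.9, gamma=0.99, lam=0.95)
adv0, _ = compute_gae(rewards, values, dones, last_value=9.9, gamma=0.99, lam=0.0)
print(adv)   # adv[2] == 0.5: no bootstrap across the terminal step
print(adv0)  # lam=0 gives one-step TD errors: [0.995, 0.995, 0.5]
```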
4) K-FAC optimizer (Linear layers)#
ACKTR replaces the vanilla gradient step with a K-FAC-preconditioned natural-gradient step.
For the policy parameters \(\theta\), the natural gradient direction is
\[ \tilde\nabla_\theta J = F^{-1} \nabla_\theta J, \qquad F = \mathbb{E}_{s,a \sim \pi_\theta}\!\big[\nabla_\theta \log \pi_\theta(a \mid s)\,\nabla_\theta \log \pi_\theta(a \mid s)^\top\big]. \]
K-FAC approximates \(F\) block-wise per layer using Kronecker factors
\[ F_\ell \approx A_\ell \otimes G_\ell, \qquad A_\ell = \mathbb{E}[a\,a^\top], \quad G_\ell = \mathbb{E}[g\,g^\top], \]
where \(a\) is the layer input and \(g\) is the gradient of the log-likelihood with respect to the layer's pre-activation output.
For a linear layer, this yields the matrix-form update (with damping \(\epsilon\)):
\[ \Delta W = (G_\ell + \epsilon I)^{-1}\,\nabla_W \mathcal{L}\,(A_\ell + \epsilon I)^{-1} \]
We also apply a trust-region-style scaling so the policy does not change too much.
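Before the implementation, a small numerical check of the Kronecker identity that makes this per-layer update cheap: preconditioning with \((A \otimes G)^{-1}\) in vectorized form equals the matrix form \(G^{-1}\,\nabla_W\,A^{-1}\). The factors here are toy SPD matrices, not statistics from the notebook run:

```python
import torch

torch.manual_seed(0)
# Check: (A ⊗ G)^{-1} vec(∇W) == vec(G^{-1} ∇W A^{-1}), with column-major vec().
n_in, n_out = 3, 2
R = torch.randn(n_in, n_in, dtype=torch.float64)
A = R @ R.T + 0.1 * torch.eye(n_in, dtype=torch.float64)    # input factor (SPD)
S = torch.randn(n_out, n_out, dtype=torch.float64)
G = S @ S.T + 0.1 * torch.eye(n_out, dtype=torch.float64)   # grad factor (SPD)
grad_w = torch.randn(n_out, n_in, dtype=torch.float64)

vec = lambda M: M.T.reshape(-1)  # column-major vectorization
lhs = torch.linalg.inv(torch.kron(A, G)) @ vec(grad_w)
rhs = vec(torch.linalg.inv(G) @ grad_w @ torch.linalg.inv(A))
print(torch.allclose(lhs, rhs, atol=1e-8))  # True
```

This is why the `step()` below never forms the full Fisher block: two small inverses replace one inverse of size (in+1)·out.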
class KFACOptimizer:
"""Minimal K-FAC for nn.Linear modules (actor only).
- Collects factor stats (A,G) via forward/backward hooks on a Fisher-like loss.
- Preconditions parameter gradients with G^{-1} @ grad @ A^{-1}.
- Scales the step using a trust-region clip.
"""
def __init__(
self,
model: nn.Module,
lr: float,
damping: float,
stats_decay: float,
kfac_clip: float,
inverse_update_interval: int = 1,
):
self.model = model
self.lr = float(lr)
self.damping = float(damping)
self.stats_decay = float(stats_decay)
self.kfac_clip = float(kfac_clip)
self.inverse_update_interval = int(inverse_update_interval)
self._collect_stats = False
self._step = 0
self.modules = []
self.state = {}
for module in self.model.modules():
if isinstance(module, nn.Linear):
self.modules.append(module)
self.state[module] = {
'A': None,
'G': None,
'A_inv': None,
'G_inv': None,
}
module.register_forward_hook(self._forward_hook)
module.register_full_backward_hook(self._backward_hook)
def set_collect_stats(self, collect: bool):
self._collect_stats = bool(collect)
def _forward_hook(self, module, inputs, output):
if not self._collect_stats:
return
module._kfac_input = inputs[0].detach()
def _backward_hook(self, module, grad_input, grad_output):
if not self._collect_stats:
return
module._kfac_grad_output = grad_output[0].detach()
@torch.no_grad()
def update_stats(self):
for module in self.modules:
if not hasattr(module, '_kfac_input') or not hasattr(module, '_kfac_grad_output'):
continue
a = module._kfac_input
g = module._kfac_grad_output
if a.dim() != 2 or g.dim() != 2:
continue
batch = a.shape[0]
ones = torch.ones(batch, 1, device=a.device, dtype=a.dtype)
a_aug = torch.cat([a, ones], dim=1)
A_new = (a_aug.t() @ a_aug) / batch
G_new = (g.t() @ g) / batch
st = self.state[module]
if st['A'] is None:
st['A'] = A_new
st['G'] = G_new
else:
d = self.stats_decay
st['A'] = d * st['A'] + (1 - d) * A_new
st['G'] = d * st['G'] + (1 - d) * G_new
self._step += 1
if self._step % self.inverse_update_interval == 0:
self._update_inverses()
@torch.no_grad()
def _update_inverses(self):
for module in self.modules:
st = self.state[module]
if st['A'] is None or st['G'] is None:
continue
A = st['A'] + self.damping * torch.eye(st['A'].shape[0], device=st['A'].device, dtype=st['A'].dtype)
G = st['G'] + self.damping * torch.eye(st['G'].shape[0], device=st['G'].device, dtype=st['G'].dtype)
st['A_inv'] = torch.linalg.inv(A)
st['G_inv'] = torch.linalg.inv(G)
@torch.no_grad()
def step(self):
eps = 1e-8
shs = 0.0 # proxy for g^T F^{-1} g
updates = []
for module in self.modules:
st = self.state[module]
if st['A_inv'] is None or st['G_inv'] is None:
continue
if module.weight.grad is None:
continue
if module.bias is None or module.bias.grad is None:
continue
grad_w = module.weight.grad
grad_b = module.bias.grad
grad_wb = torch.cat([grad_w, grad_b.unsqueeze(1)], dim=1)
nat_wb = st['G_inv'] @ grad_wb @ st['A_inv']
nat_w = nat_wb[:, :-1]
nat_b = nat_wb[:, -1]
shs += float((grad_w * nat_w).sum().item() + (grad_b * nat_b).sum().item())
updates.append((module.weight, nat_w))
updates.append((module.bias, nat_b))
# Trust-region / KL clip: only scale down.
# (Theory: predicted KL \approx 0.5 * alpha^2 * g^T F^{-1} g)
nu = 1.0
if shs > 0:
predicted_kl = 0.5 * shs
nu = float(min(1.0, np.sqrt(self.kfac_clip / (predicted_kl + eps))))
else:
predicted_kl = 0.0
for param, nat_grad in updates:
param.add_(nat_grad, alpha=-self.lr * nu)
return {
'shs': shs,
'predicted_kl': predicted_kl,
'nu': nu,
}
actor_kfac = KFACOptimizer(
actor,
lr=ACTOR_LR,
damping=KFAC_DAMPING,
stats_decay=KFAC_STATS_DECAY,
kfac_clip=KFAC_CLIP,
inverse_update_interval=INVERSE_UPDATE_INTERVAL,
)
5) Training loop (ACKTR update)#
Each update:
Collect \(T\) on-policy transitions.
Compute \(\hat A_t\) and \(\hat R_t\) with GAE.
Update critic by minimizing \(\mathcal{L}_V\).
For the actor:
build a Fisher-like loss (to collect K-FAC stats)
backprop that loss to update \(A\) and \(G\)
backprop the policy loss and take a K-FAC-preconditioned step
We log:
episodic returns
actor loss, critic loss, entropy
estimated KL between the pre- and post-update policy
trust-region scale factor \(\nu\)
num_updates = TOTAL_TIMESTEPS // ROLLOUT_STEPS
print('num_updates', num_updates)
obs, _ = env.reset(seed=SEED)
episode_return = 0.0
episode_len = 0
episode_returns = []
episode_lengths = []
logs = []
start = time.time()
for update in range(1, num_updates + 1):
# --- Rollout buffers ---
obs_buf = np.zeros((ROLLOUT_STEPS, obs_dim), dtype=np.float32)
act_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.int64)
rew_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
done_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
val_buf = np.zeros((ROLLOUT_STEPS,), dtype=np.float32)
for t in range(ROLLOUT_STEPS):
obs_buf[t] = obs
obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
with torch.no_grad():
logits = actor(obs_t)
dist = Categorical(logits=logits)
action = dist.sample()
value = critic(obs_t)
next_obs, reward, terminated, truncated, _ = env.step(int(action.item()))
done = bool(terminated or truncated)
act_buf[t] = int(action.item())
rew_buf[t] = float(reward)
done_buf[t] = float(done)
val_buf[t] = float(value.item())
episode_return += float(reward)
episode_len += 1
obs = next_obs
if done:
episode_returns.append(episode_return)
episode_lengths.append(episode_len)
episode_return = 0.0
episode_len = 0
obs, _ = env.reset()
with torch.no_grad():
if done_buf[-1] == 1.0:
last_value = 0.0
else:
last_obs_t = torch.tensor(obs, dtype=torch.float32, device=DEVICE).unsqueeze(0)
last_value = float(critic(last_obs_t).item())
advantages, returns = compute_gae(
rewards=rew_buf,
values=val_buf,
dones=done_buf,
last_value=last_value,
gamma=GAMMA,
lam=GAE_LAMBDA,
)
obs_batch = torch.tensor(obs_buf, dtype=torch.float32, device=DEVICE)
act_batch = torch.tensor(act_buf, dtype=torch.int64, device=DEVICE)
adv_batch = torch.tensor(advantages, dtype=torch.float32, device=DEVICE)
ret_batch = torch.tensor(returns, dtype=torch.float32, device=DEVICE)
adv_batch = (adv_batch - adv_batch.mean()) / (adv_batch.std() + 1e-8)
# --- Critic update (first-order) ---
critic_optim.zero_grad(set_to_none=True)
v_pred = critic(obs_batch)
critic_loss = 0.5 * (ret_batch - v_pred).pow(2).mean()
critic_loss.backward()
critic_optim.step()
# --- Actor update (ACKTR-style via K-FAC) ---
actor_kfac.set_collect_stats(True)
logits_old = actor(obs_batch).detach()
dist_old = Categorical(logits=logits_old)
logits = actor(obs_batch)
dist = Categorical(logits=logits)
logp = dist.log_prob(act_batch)
entropy = dist.entropy().mean()
actor_loss = -(logp * adv_batch.detach()).mean() - ENT_COEF * entropy
# Fisher-like loss: E[-log pi(a|s)]
fisher_loss = -logp.mean()
actor.zero_grad(set_to_none=True)
fisher_loss.backward(retain_graph=True)
actor_kfac.set_collect_stats(False)
actor_kfac.update_stats()
actor.zero_grad(set_to_none=True)
actor_loss.backward()
step_info = actor_kfac.step()
with torch.no_grad():
logits_new = actor(obs_batch)
dist_new = Categorical(logits=logits_new)
approx_kl = torch.distributions.kl_divergence(dist_old, dist_new).mean().item()
logs.append(
{
'update': update,
'timesteps': update * ROLLOUT_STEPS,
'episodes': len(episode_returns),
'actor_loss': float(actor_loss.item()),
'critic_loss': float(critic_loss.item()),
'entropy': float(entropy.item()),
'approx_kl': float(approx_kl),
**step_info,
}
)
if update % 25 == 0:
recent = episode_returns[-20:]
mean_20 = float(np.mean(recent)) if recent else float('nan')
elapsed = time.time() - start
print(
f'update {update:4d}/{num_updates} | episodes {len(episode_returns):4d} '
f'| mean_return_20 {mean_20:7.2f} | kl {approx_kl:9.2e} | nu {step_info["nu"]:7.3f} '
f'| elapsed {elapsed:6.1f}s'
)
env.close()
num_updates 312
update 25/312 | episodes 56 | mean_return_20 93.55 | kl 4.40e-03 | nu 0.062 | elapsed 0.6s
update 50/312 | episodes 69 | mean_return_20 180.05 | kl 3.59e-02 | nu 0.066 | elapsed 1.1s
update 75/312 | episodes 87 | mean_return_20 209.00 | kl 5.94e-02 | nu 0.044 | elapsed 1.7s
update 100/312 | episodes 104 | mean_return_20 174.20 | kl 3.60e-02 | nu 0.036 | elapsed 2.3s
update 125/312 | episodes 121 | mean_return_20 178.80 | kl 3.47e-02 | nu 0.030 | elapsed 2.8s
update 150/312 | episodes 140 | mean_return_20 159.80 | kl 3.95e-02 | nu 0.044 | elapsed 3.4s
update 175/312 | episodes 159 | mean_return_20 172.60 | kl 1.87e-02 | nu 0.040 | elapsed 4.0s
update 200/312 | episodes 173 | mean_return_20 216.50 | kl 6.78e-02 | nu 0.019 | elapsed 4.6s
update 225/312 | episodes 192 | mean_return_20 156.50 | kl 9.28e-02 | nu 0.197 | elapsed 5.2s
update 250/312 | episodes 211 | mean_return_20 186.60 | kl 3.06e-02 | nu 0.049 | elapsed 5.7s
update 275/312 | episodes 245 | mean_return_20 90.65 | kl 1.04e-02 | nu 0.063 | elapsed 6.4s
update 300/312 | episodes 273 | mean_return_20 118.30 | kl 2.05e-02 | nu 0.048 | elapsed 6.9s
6) Plotly: learning dynamics#
We visualize:
episodic reward progression (raw + smoothed)
estimated KL per update
actor/critic losses
trust-region scaling factor \(\nu\)
df_logs = pd.DataFrame(logs)
df_eps = pd.DataFrame({'episode': np.arange(1, len(episode_returns) + 1), 'return': episode_returns})
df_eps['return_smooth'] = df_eps['return'].rolling(window=20, min_periods=1).mean()
fig = go.Figure()
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return'], mode='lines', name='return', line=dict(width=1)))
fig.add_trace(go.Scatter(x=df_eps['episode'], y=df_eps['return_smooth'], mode='lines', name='return (20-ep mean)', line=dict(width=3)))
fig.update_layout(
title='Episodic reward progression (CartPole-v1)',
xaxis_title='Episode',
yaxis_title='Return',
height=420,
)
fig.show()
fig2 = px.line(df_logs, x='timesteps', y=['approx_kl', 'predicted_kl'], title='KL diagnostics per update')
fig2.update_layout(height=380)
fig2.show()
fig3 = px.line(df_logs, x='timesteps', y=['actor_loss', 'critic_loss'], title='Losses per update')
fig3.update_layout(height=380)
fig3.show()
fig4 = px.line(df_logs, x='timesteps', y=['nu'], title='Trust-region scaling (nu)')
fig4.update_layout(height=320)
fig4.show()
7) Stable-Baselines ACKTR (reference)#
We’ll include a reference snippet for the (TensorFlow-based) Stable-Baselines implementation of ACKTR, plus an explanation of its key hyperparameters.
This section is reference only — the implementation above is the main deliverable.
Stable-Baselines usage (snippet)#
# pip install stable-baselines==2.* (TensorFlow 1.x based)
from stable_baselines import ACKTR
model = ACKTR(
policy='MlpPolicy',
env='CartPole-v1',
n_steps=20,
gamma=0.99,
ent_coef=0.01,
vf_coef=0.25,
vf_fisher_coef=1.0,
learning_rate=0.25,
max_grad_norm=0.5,
kfac_clip=0.001,
lr_schedule='linear',
kfac_update=1,
gae_lambda=None,
verbose=1,
)
model.learn(total_timesteps=200_000)
Hyperparameters (Stable-Baselines) explained#
Stable-Baselines (“v2”, TensorFlow 1.x) includes an ACKTR implementation (see stable_baselines/acktr/acktr.py). The constructor signature is:
ACKTR(
policy,
env,
gamma=0.99,
n_steps=20,
ent_coef=0.01,
vf_coef=0.25,
vf_fisher_coef=1.0,
learning_rate=0.25,
max_grad_norm=0.5,
kfac_clip=0.001,
lr_schedule='linear',
async_eigen_decomp=False,
kfac_update=1,
gae_lambda=None,
policy_kwargs=None,
seed=None,
n_cpu_tf_sess=1,
# + logging/boilerplate args
)
Core RL knobs
gamma: discount factor.
n_steps: rollout length per environment before each update.
gae_lambda: if not None, Stable-Baselines computes GAE with parameter \(\lambda\); if None, it uses the classic advantage (no GAE).
ent_coef: entropy bonus weight (encourages exploration).
vf_coef: value loss weight in the joint loss.
ACKTR / K-FAC + trust region knobs
kfac_clip: KL-based clip used inside the K-FAC optimizer (a trust-region-like safeguard; called clip_kl in the underlying optimizer).
vf_fisher_coef: weight on the value-function Fisher loss. In the Stable-Baselines code, the value Fisher is constructed by adding noise to the value output and backpropagating a Gaussian negative log-likelihood; this lets K-FAC build curvature statistics for the critic.
learning_rate: the step size used by the K-FAC optimizer (scheduled by lr_schedule).
lr_schedule: learning-rate schedule string ('linear', 'constant', 'double_linear_con', 'middle_drop', 'double_middle_drop').
kfac_update: update frequency for K-FAC statistics / eigendecompositions.
async_eigen_decomp: compute eigendecompositions asynchronously (a speed/throughput trade-off).
max_grad_norm: global gradient-norm clipping.
Practical / reproducibility knobs
policy: policy network type (e.g. MlpPolicy, CnnPolicy, CnnLstmPolicy).
env: Gym env instance or env id string.
policy_kwargs: extra arguments forwarded to the policy.
seed: seeds the Python/NumPy/TensorFlow RNGs.
n_cpu_tf_sess: TensorFlow thread count (for determinism, set this to 1).
Note: Stable-Baselines wires ACKTR into kfac.KfacOptimizer(...) with additional internal defaults (e.g. momentum=0.9, epsilon=0.01, stats_decay=0.99, cold_iter=10).